{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 分布式训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.0.0\n",
      "sys.version_info(major=3, minor=6, micro=10, releaselevel='final', serial=0)\n",
      "matplotlib 3.1.2\n",
      "numpy 1.18.1\n",
      "pandas 0.25.3\n",
      "sklearn 0.22.1\n",
      "tensorflow 2.0.0\n",
      "tensorflow_core.keras 2.2.4-tf\n"
     ]
    }
   ],
   "source": [
    "# 导入\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "import numpy as np\n",
    "import sklearn\n",
    "import pandas as pd\n",
    "import os\n",
    "import sys\n",
    "import time\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "\n",
    "# Report the interpreter and library versions used by this notebook.\n",
    "print(tf.__version__)\n",
    "print(sys.version_info)\n",
    "for lib in (mpl, np, pd, sklearn, tf, keras):\n",
    "    print(lib.__name__, lib.__version__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1\n",
      "2\n"
     ]
    }
   ],
   "source": [
    "# Log the device each op runs on so replica placement can be inspected.\n",
    "tf.debugging.set_log_device_placement(True)\n",
    "\n",
    "gpus = tf.config.experimental.list_physical_devices('GPU')\n",
    "if gpus:\n",
    "    # Split the first physical GPU into two logical GPUs (memory_limit=2000\n",
    "    # each) so that two replicas are available on a single-GPU machine.\n",
    "    # Virtual devices have to be configured before the GPU is initialized.\n",
    "    # (Alternative when several physical GPUs exist: call\n",
    "    # tf.config.experimental.set_memory_growth(gpu, True) for each of them.)\n",
    "    # The `if gpus` guard avoids an IndexError on gpus[0] on CPU-only hosts.\n",
    "    tf.config.experimental.set_virtual_device_configuration(\n",
    "        gpus[0],\n",
    "        [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=2000),\n",
    "         tf.config.experimental.VirtualDeviceConfiguration(memory_limit=2000)]\n",
    "    )\n",
    "print(len(gpus))\n",
    "\n",
    "logical_gpus = tf.config.experimental.list_logical_devices('GPU')\n",
    "print(len(logical_gpus))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 数据读取"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(5000, 28, 28) (5000,)\n",
      "(55000, 28, 28) (55000,)\n",
      "(10000, 28, 28) (10000,)\n"
     ]
    }
   ],
   "source": [
    "# Load Fashion-MNIST and hold out the first 5000 training examples for validation.\n",
    "fashion_mnist = keras.datasets.fashion_mnist\n",
    "(x_train_all, y_train_all), (x_test, y_test) = fashion_mnist.load_data()\n",
    "x_valid, y_valid = x_train_all[:5000], y_train_all[:5000]\n",
    "x_train, y_train = x_train_all[5000:], y_train_all[5000:]\n",
    "\n",
    "# Print (features, labels) shapes for the validation, training and test splits.\n",
    "for split_x, split_y in ((x_valid, y_valid), (x_train, y_train), (x_test, y_test)):\n",
    "    print(split_x.shape, split_y.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "scaler = StandardScaler()\n",
    "\n",
    "def _standardize(images, fit=False):\n",
    "    \"\"\"Scale all pixels with one shared mean/std and return shape (-1, 28, 28, 1).\n",
    "\n",
    "    Pixels are flattened into a single column so the scaler learns one global\n",
    "    mean and standard deviation. With `fit=True` the scaler is fitted on\n",
    "    `images` first; otherwise the already fitted statistics are reused.\n",
    "    \"\"\"\n",
    "    column = images.astype(np.float32).reshape(-1, 1)\n",
    "    scaled = scaler.fit_transform(column) if fit else scaler.transform(column)\n",
    "    return scaled.reshape(-1, 28, 28, 1)\n",
    "\n",
    "# Fit on the training split only, then apply the same statistics elsewhere.\n",
    "x_train_scaled = _standardize(x_train, fit=True)\n",
    "x_valid_scaled = _standardize(x_valid)\n",
    "x_test_scaled = _standardize(x_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Executing op TensorSliceDataset in device /job:localhost/replica:0/task:0/device:CPU:0\n",
      "Executing op AnonymousRandomSeedGenerator in device /job:localhost/replica:0/task:0/device:CPU:0\n",
      "Executing op ShuffleDatasetV2 in device /job:localhost/replica:0/task:0/device:CPU:0\n",
      "Executing op RepeatDataset in device /job:localhost/replica:0/task:0/device:CPU:0\n",
      "Executing op BatchDatasetV2 in device /job:localhost/replica:0/task:0/device:CPU:0\n",
      "Executing op PrefetchDataset in device /job:localhost/replica:0/task:0/device:CPU:0\n"
     ]
    }
   ],
   "source": [
    "def make_dataset(images, labels, epochs, batch_size, shuffle=True):\n",
    "    \"\"\"Build a tf.data pipeline over (images, labels) pairs.\n",
    "\n",
    "    When `shuffle` is true, examples are shuffled with a 10000-element buffer.\n",
    "    The stream is then repeated `epochs` times, grouped into batches of\n",
    "    `batch_size`, and prefetched with a buffer size of 50.\n",
    "    \"\"\"\n",
    "    pipeline = tf.data.Dataset.from_tensor_slices((images, labels))\n",
    "    pipeline = pipeline.shuffle(10000) if shuffle else pipeline\n",
    "    pipeline = pipeline.repeat(epochs)\n",
    "    pipeline = pipeline.batch(batch_size)\n",
    "    return pipeline.prefetch(50)\n",
    "\n",
    "# With Keras + tf.distribute each global batch is split across the replicas,\n",
    "# so the global batch size grows with the number of logical GPUs.\n",
    "# Clamp the replica count to at least 1: with no GPU, len(logical_gpus) is 0\n",
    "# and a batch size of 0 would make Dataset.batch fail.\n",
    "batch_size_per_replica = 128\n",
    "batch_size = batch_size_per_replica * max(1, len(logical_gpus))\n",
    "# Number of times the datasets below are repeated.\n",
    "epochs = 100\n",
    "train_dataset = make_dataset(x_train_scaled, y_train, epochs, batch_size)\n",
    "valid_dataset = make_dataset(x_valid_scaled, y_valid, epochs, batch_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Keras model上进行分布式"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Some requested devices in `tf.distribute.Strategy` are not visible to TensorFlow: /job:localhost/replica:0/task:0/device:GPU:0,/job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op RandomUniform in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op Sub in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op Mul in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op Add in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarIsInitializedOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op LogicalNot in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op Assert in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op AssignVariableOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op ReadVariableOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op VarIsInitializedOp in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op LogicalNot in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op Assert in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op AssignVariableOp in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op Fill in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n",
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n"
     ]
    }
   ],
   "source": [
    "strategy = tf.distribute.MirroredStrategy()\n",
    "\n",
    "# Build and compile inside the strategy scope so the model's variables are\n",
    "# created under the distribution strategy.\n",
    "with strategy.scope():\n",
    "    model = keras.models.Sequential([\n",
    "        # Stage 1: two 32-filter convolutions, then 2x2 max pooling.\n",
    "        keras.layers.Conv2D(filters=32, kernel_size=3, padding='same',\n",
    "                            activation='relu', input_shape=(28,28,1)),\n",
    "        keras.layers.Conv2D(filters=32, kernel_size=3, padding='same', activation='relu'),\n",
    "        keras.layers.MaxPool2D(pool_size=2),\n",
    "        # Stage 2: two 64-filter convolutions, then 2x2 max pooling.\n",
    "        keras.layers.Conv2D(filters=64, kernel_size=3, padding='same', activation='relu'),\n",
    "        keras.layers.Conv2D(filters=64, kernel_size=3, padding='same', activation='relu'),\n",
    "        keras.layers.MaxPool2D(pool_size=2),\n",
    "        # Stage 3: two 128-filter convolutions, then 2x2 max pooling.\n",
    "        keras.layers.Conv2D(filters=128, kernel_size=3, padding='same', activation='relu'),\n",
    "        keras.layers.Conv2D(filters=128, kernel_size=3, padding='same', activation='relu'),\n",
    "        keras.layers.MaxPool2D(pool_size=2),\n",
    "        # Classifier head: 10-way softmax.\n",
    "        keras.layers.Flatten(),\n",
    "        keras.layers.Dense(128, activation='relu'),\n",
    "        keras.layers.Dense(10, activation='softmax'),\n",
    "    ])\n",
    "\n",
    "    # Compile the model (integer labels, hence the sparse loss).\n",
    "    model.compile(loss='sparse_categorical_crossentropy',\n",
    "                  optimizer='adam',\n",
    "                  metrics=['accuracy'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "conv2d (Conv2D)              (None, 28, 28, 32)        320       \n",
      "_________________________________________________________________\n",
      "conv2d_1 (Conv2D)            (None, 28, 28, 32)        9248      \n",
      "_________________________________________________________________\n",
      "max_pooling2d (MaxPooling2D) (None, 14, 14, 32)        0         \n",
      "_________________________________________________________________\n",
      "conv2d_2 (Conv2D)            (None, 14, 14, 64)        18496     \n",
      "_________________________________________________________________\n",
      "conv2d_3 (Conv2D)            (None, 14, 14, 64)        36928     \n",
      "_________________________________________________________________\n",
      "max_pooling2d_1 (MaxPooling2 (None, 7, 7, 64)          0         \n",
      "_________________________________________________________________\n",
      "conv2d_4 (Conv2D)            (None, 7, 7, 128)         73856     \n",
      "_________________________________________________________________\n",
      "conv2d_5 (Conv2D)            (None, 7, 7, 128)         147584    \n",
      "_________________________________________________________________\n",
      "max_pooling2d_2 (MaxPooling2 (None, 3, 3, 128)         0         \n",
      "_________________________________________________________________\n",
      "flatten (Flatten)            (None, 1152)              0         \n",
      "_________________________________________________________________\n",
      "dense (Dense)                (None, 128)               147584    \n",
      "_________________________________________________________________\n",
      "dense_1 (Dense)              (None, 10)                1290      \n",
      "=================================================================\n",
      "Total params: 435,306\n",
      "Trainable params: 435,306\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train: each epoch runs as many global batches as fit into the training set.\n",
    "steps_per_epoch = x_train_scaled.shape[0] // batch_size\n",
    "history = model.fit(train_dataset,\n",
    "                    steps_per_epoch=steps_per_epoch,\n",
    "                    epochs=10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Estimator上进行分布式"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Executing op RandomUniform in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op Sub in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op Mul in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op Add in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarIsInitializedOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op LogicalNot in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op Assert in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op AssignVariableOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op Fill in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n"
     ]
    }
   ],
   "source": [
    "# Three conv stages (32, 64, 128 filters), each: Conv -> Conv -> MaxPool.\n",
    "model = keras.models.Sequential()\n",
    "for stage, n_filters in enumerate((32, 64, 128)):\n",
    "    # Only the very first layer of the network declares the input shape.\n",
    "    first_layer_kwargs = {'input_shape': (28,28,1)} if stage == 0 else {}\n",
    "    model.add(keras.layers.Conv2D(filters=n_filters, kernel_size=3, padding='same',\n",
    "                                  activation='relu', **first_layer_kwargs))\n",
    "    model.add(keras.layers.Conv2D(filters=n_filters, kernel_size=3, padding='same',\n",
    "                                  activation='relu'))\n",
    "    model.add(keras.layers.MaxPool2D(pool_size=2))\n",
    "\n",
    "# Classifier head: 10-way softmax.\n",
    "model.add(keras.layers.Flatten())\n",
    "model.add(keras.layers.Dense(128, activation='relu'))\n",
    "model.add(keras.layers.Dense(10, activation='softmax'))\n",
    "\n",
    "# Compile the model (integer labels, hence the sparse loss).\n",
    "model.compile(loss='sparse_categorical_crossentropy',\n",
    "              optimizer='adam',\n",
    "              metrics=['accuracy'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Some requested devices in `tf.distribute.Strategy` are not visible to TensorFlow: /job:localhost/replica:0/task:0/device:GPU:1,/job:localhost/replica:0/task:0/device:GPU:0\n",
      "INFO:tensorflow:Initializing RunConfig with distribution strategies.\n",
      "INFO:tensorflow:Not using Distribute Coordinator.\n",
      "WARNING:tensorflow:Using temporary folder as model directory: C:\\Users\\Awebone\\AppData\\Local\\Temp\\tmpui5k1w17\n",
      "INFO:tensorflow:Using the Keras model provided.\n",
      "Executing op ReadVariableOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "WARNING:tensorflow:From f:\\condaenv\\tf2_py36\\lib\\site-packages\\tensorflow_core\\python\\ops\\resource_variable_ops.py:1630: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "If using Keras pass *_constraint arguments to layers.\n",
      "INFO:tensorflow:Using config: {'_model_dir': 'C:\\\\Users\\\\Awebone\\\\AppData\\\\Local\\\\Temp\\\\tmpui5k1w17', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true\n",
      "graph_options {\n",
      "  rewrite_options {\n",
      "    meta_optimizer_iterations: ONE\n",
      "  }\n",
      "}\n",
      ", '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x0000020B0F27B668>, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x0000020B0F27B898>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_distribute_coordinator_mode': None}\n"
     ]
    }
   ],
   "source": [
    "# Distribute Estimator training: hand a MirroredStrategy to the RunConfig,\n",
    "# then convert the compiled Keras model into an Estimator using that config.\n",
    "strategy = tf.distribute.MirroredStrategy()\n",
    "config = tf.estimator.RunConfig(train_distribute=strategy)\n",
    "estimator = keras.estimator.model_to_estimator(keras_model=model, config=config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "conv2d (Conv2D)              (None, 28, 28, 32)        320       \n",
      "_________________________________________________________________\n",
      "conv2d_1 (Conv2D)            (None, 28, 28, 32)        9248      \n",
      "_________________________________________________________________\n",
      "max_pooling2d (MaxPooling2D) (None, 14, 14, 32)        0         \n",
      "_________________________________________________________________\n",
      "conv2d_2 (Conv2D)            (None, 14, 14, 64)        18496     \n",
      "_________________________________________________________________\n",
      "conv2d_3 (Conv2D)            (None, 14, 14, 64)        36928     \n",
      "_________________________________________________________________\n",
      "max_pooling2d_1 (MaxPooling2 (None, 7, 7, 64)          0         \n",
      "_________________________________________________________________\n",
      "conv2d_4 (Conv2D)            (None, 7, 7, 128)         73856     \n",
      "_________________________________________________________________\n",
      "conv2d_5 (Conv2D)            (None, 7, 7, 128)         147584    \n",
      "_________________________________________________________________\n",
      "max_pooling2d_2 (MaxPooling2 (None, 3, 3, 128)         0         \n",
      "_________________________________________________________________\n",
      "flatten (Flatten)            (None, 1152)              0         \n",
      "_________________________________________________________________\n",
      "dense (Dense)                (None, 128)               147584    \n",
      "_________________________________________________________________\n",
      "dense_1 (Dense)              (None, 10)                1290      \n",
      "=================================================================\n",
      "Total params: 435,306\n",
      "Trainable params: 435,306\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_input_fn():\n",
    "    \"\"\"Return the training dataset consumed by the Estimator.\n",
    "\n",
    "    Unlike Keras `fit`, Estimator does not split a global batch across\n",
    "    replicas: every replica receives one full batch from this dataset.\n",
    "    The dataset is therefore batched with the per-replica size; using the\n",
    "    global `batch_size` here would multiply the effective batch size by the\n",
    "    number of replicas.\n",
    "    \"\"\"\n",
    "    return make_dataset(x_train_scaled, y_train, epochs, batch_size_per_replica)\n",
    "\n",
    "estimator.train(input_fn=train_input_fn, max_steps=5000)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 自定义训练流程进行分布式\n",
    "### 单机版的自定义训练流程"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Executing op TensorSliceDataset in device /job:localhost/replica:0/task:0/device:CPU:0\n",
      "Executing op AnonymousRandomSeedGenerator in device /job:localhost/replica:0/task:0/device:CPU:0\n",
      "Executing op ShuffleDatasetV2 in device /job:localhost/replica:0/task:0/device:CPU:0\n",
      "Executing op RepeatDataset in device /job:localhost/replica:0/task:0/device:CPU:0\n",
      "Executing op BatchDatasetV2 in device /job:localhost/replica:0/task:0/device:CPU:0\n",
      "Executing op PrefetchDataset in device /job:localhost/replica:0/task:0/device:CPU:0\n"
     ]
    }
   ],
   "source": [
    "def make_dataset(images, labels, epochs, batch_size, shuffle=True):\n",
    "    \"\"\"Turn (images, labels) arrays into a batched tf.data.Dataset.\n",
    "\n",
    "    Order of operations: optional shuffle (buffer of 10000), repeat `epochs`\n",
    "    times, batch by `batch_size`, prefetch with a buffer size of 50.\n",
    "    \"\"\"\n",
    "    examples = tf.data.Dataset.from_tensor_slices((images, labels))\n",
    "    if shuffle:\n",
    "        examples = examples.shuffle(10000)\n",
    "    repeated = examples.repeat(epochs)\n",
    "    batched = repeated.batch(batch_size)\n",
    "    return batched.prefetch(50)\n",
    "\n",
    "# Single-device setup: batch size 128 and datasets that are not repeated (epochs=1).\n",
    "batch_size = 128\n",
    "train_dataset, valid_dataset = (\n",
    "    make_dataset(split_x, split_y, 1, batch_size)\n",
    "    for split_x, split_y in ((x_train_scaled, y_train), (x_valid_scaled, y_valid)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Executing op RandomUniform in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op Sub in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op Mul in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op Add in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarIsInitializedOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op LogicalNot in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op Assert in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op AssignVariableOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op Fill in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n"
     ]
    }
   ],
   "source": [
    "# Settings shared by every convolution in the network.\n",
    "conv_args = dict(kernel_size=3, padding='same', activation='relu')\n",
    "\n",
    "model = keras.models.Sequential()\n",
    "# Stage 1: 32 filters (the first layer also declares the input shape).\n",
    "model.add(keras.layers.Conv2D(filters=32, input_shape=(28,28,1), **conv_args))\n",
    "model.add(keras.layers.Conv2D(filters=32, **conv_args))\n",
    "model.add(keras.layers.MaxPool2D(pool_size=2))\n",
    "\n",
    "# Stage 2: 64 filters.\n",
    "model.add(keras.layers.Conv2D(filters=64, **conv_args))\n",
    "model.add(keras.layers.Conv2D(filters=64, **conv_args))\n",
    "model.add(keras.layers.MaxPool2D(pool_size=2))\n",
    "\n",
    "# Stage 3: 128 filters.\n",
    "model.add(keras.layers.Conv2D(filters=128, **conv_args))\n",
    "model.add(keras.layers.Conv2D(filters=128, **conv_args))\n",
    "model.add(keras.layers.MaxPool2D(pool_size=2))\n",
    "\n",
    "# Classifier head: 10-way softmax.\n",
    "model.add(keras.layers.Flatten())\n",
    "model.add(keras.layers.Dense(128, activation='relu'))\n",
    "model.add(keras.layers.Dense(10, activation='softmax'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "conv2d (Conv2D)              (None, 28, 28, 32)        320       \n",
      "_________________________________________________________________\n",
      "conv2d_1 (Conv2D)            (None, 28, 28, 32)        9248      \n",
      "_________________________________________________________________\n",
      "max_pooling2d (MaxPooling2D) (None, 14, 14, 32)        0         \n",
      "_________________________________________________________________\n",
      "conv2d_2 (Conv2D)            (None, 14, 14, 64)        18496     \n",
      "_________________________________________________________________\n",
      "conv2d_3 (Conv2D)            (None, 14, 14, 64)        36928     \n",
      "_________________________________________________________________\n",
      "max_pooling2d_1 (MaxPooling2 (None, 7, 7, 64)          0         \n",
      "_________________________________________________________________\n",
      "conv2d_4 (Conv2D)            (None, 7, 7, 128)         73856     \n",
      "_________________________________________________________________\n",
      "conv2d_5 (Conv2D)            (None, 7, 7, 128)         147584    \n",
      "_________________________________________________________________\n",
      "max_pooling2d_2 (MaxPooling2 (None, 3, 3, 128)         0         \n",
      "_________________________________________________________________\n",
      "flatten (Flatten)            (None, 1152)              0         \n",
      "_________________________________________________________________\n",
      "dense (Dense)                (None, 128)               147584    \n",
      "_________________________________________________________________\n",
      "dense_1 (Dense)              (None, 10)                1290      \n",
      "=================================================================\n",
      "Total params: 435,306\n",
      "Trainable params: 435,306\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "customized training loop:\n",
    "1. define loss functions\n",
    "2. define function train_step\n",
    "3. define function test_step\n",
    "4. for-loop training loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Executing op OptimizeDataset in device /job:localhost/replica:0/task:0/device:CPU:0\n",
      "Executing op ModelDataset in device /job:localhost/replica:0/task:0/device:CPU:0\n",
      "Executing op AnonymousIteratorV2 in device /job:localhost/replica:0/task:0/device:CPU:0\n",
      "Executing op MakeIterator in device /job:localhost/replica:0/task:0/device:CPU:0\n",
      "Executing op IteratorGetNextSync in device /job:localhost/replica:0/task:0/device:CPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op AssignVariableOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op __inference_initialize_variables_629 in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op __inference_train_step_845 in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op AddV2 in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op RealDiv in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "total_loss:608.479, num_batches:429.000, average_loss:1.418, time:0.000Executing op __inference_train_step_3636 in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "total_loss:609.399, num_batches:430.000, average_loss:1.417, time:2.990Executing op DeleteIterator in device /job:localhost/replica:0/task:0/device:CPU:0\n",
      "Executing op __inference_test_step_3755 in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op __inference_test_step_3938 in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op ReadVariableOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op DivNoNan in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Epoch:1, Loss:1.417, Acc:0.542, Val_Loss:0.743, Val_Acc:0.728\n",
      "Epoch:2, Loss:0.655, Acc:0.759, Val_Loss:0.540, Val_Acc:0.809time:0.003\n",
      "Epoch:3, Loss:0.527, Acc:0.806, Val_Loss:0.469, Val_Acc:0.837time:0.000\n",
      "Epoch:4, Loss:0.458, Acc:0.831, Val_Loss:0.432, Val_Acc:0.849time:0.000\n",
      "Epoch:5, Loss:0.411, Acc:0.849, Val_Loss:0.395, Val_Acc:0.858time:0.009\n",
      "Epoch:6, Loss:0.386, Acc:0.858, Val_Loss:0.368, Val_Acc:0.871time:0.005\n",
      "Epoch:7, Loss:0.363, Acc:0.868, Val_Loss:0.339, Val_Acc:0.874time:0.010\n",
      "Epoch:8, Loss:0.345, Acc:0.873, Val_Loss:0.357, Val_Acc:0.870time:0.010\n",
      "Epoch:9, Loss:0.329, Acc:0.880, Val_Loss:0.328, Val_Acc:0.878time:0.008\n",
      "Epoch:10, Loss:0.318, Acc:0.883, Val_Loss:0.303, Val_Acc:0.889ime:0.007\n"
     ]
    }
   ],
   "source": [
    "# 1.define losses functions\n",
    "loss_func = keras.losses.SparseCategoricalCrossentropy(\n",
    "    reduction = keras.losses.Reduction.SUM_OVER_BATCH_SIZE)\n",
    "test_loss = keras.metrics.Mean(name = 'test_loss')\n",
    "train_accuracy = keras.metrics.SparseCategoricalAccuracy(name = 'train_accuracy')\n",
    "test_accuracy = keras.metrics.SparseCategoricalAccuracy(name = 'test_accuracy')\n",
    "\n",
    "optimizer = keras.optimizers.SGD(lr = 0.01)\n",
    "\n",
    "# 有图计算：正向计算和反向传播，所以添加tf.function加速\n",
    "# 2.define function train_step\n",
    "@tf.function\n",
    "def train_step(inputs):\n",
    "    images, labels = inputs\n",
    "    with tf.GradientTape() as tape:\n",
    "        predictions = model(images, training = True)\n",
    "        loss = loss_func(labels, predictions)\n",
    "    gradients = tape.gradient(loss, model.trainable_variables)\n",
    "    optimizer.apply_gradients(zip(gradients, model.trainable_variables))\n",
    "    train_accuracy.update_state(labels, predictions)\n",
    "    return loss\n",
    "\n",
    "# 3.define function test_step\n",
    "@tf.function\n",
    "def test_step(inputs):\n",
    "    images, labels = inputs\n",
    "    predictions = model(images)\n",
    "    t_loss = loss_func(labels, predictions)\n",
    "    test_loss.update_state(t_loss)\n",
    "    test_accuracy.update_state(labels, predictions)\n",
    "\n",
    "# 4.for-loop training loop\n",
    "epochs = 10\n",
    "for epoch in range(epochs):\n",
    "    total_loss = 0.0\n",
    "    num_batches = 0\n",
    "    # 在训练集训练\n",
    "    for x in train_dataset:\n",
    "        start_time = time.time()\n",
    "        total_loss += train_step(x)\n",
    "        run_time = time.time() - start_time # 计算时间\n",
    "        num_batches += 1\n",
    "        print('\\rtotal_loss:%3.3f, num_batches:%3.3f, average_loss:%3.3f, time:%3.3f' \n",
    "              % (total_loss, num_batches, total_loss / num_batches, run_time), end='')\n",
    "    train_loss = total_loss / num_batches\n",
    "    \n",
    "    # 验证集验证\n",
    "    for x in valid_dataset:\n",
    "        test_step(x)\n",
    "    \n",
    "    # 打印日志\n",
    "    print('\\rEpoch:%d, Loss:%3.3f, Acc:%3.3f, Val_Loss:%3.3f, Val_Acc:%3.3f' \n",
    "          % (epoch + 1, train_loss, train_accuracy.result(), \n",
    "             test_loss.result(), test_accuracy.result()))\n",
    "    # 清空累计值\n",
    "    test_loss.reset_states()\n",
    "    train_accuracy.reset_states()\n",
    "    test_accuracy.reset_states()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 分布式自定义训练流程"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Some requested devices in `tf.distribute.Strategy` are not visible to TensorFlow: /job:localhost/replica:0/task:0/device:GPU:1,/job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op TensorSliceDataset in device /job:localhost/replica:0/task:0/device:CPU:0\n",
      "Executing op AnonymousRandomSeedGenerator in device /job:localhost/replica:0/task:0/device:CPU:0\n",
      "Executing op ShuffleDatasetV2 in device /job:localhost/replica:0/task:0/device:CPU:0\n",
      "Executing op RepeatDataset in device /job:localhost/replica:0/task:0/device:CPU:0\n",
      "Executing op BatchDatasetV2 in device /job:localhost/replica:0/task:0/device:CPU:0\n",
      "Executing op PrefetchDataset in device /job:localhost/replica:0/task:0/device:CPU:0\n",
      "Executing op RebatchDataset in device /job:localhost/replica:0/task:0/device:CPU:0\n",
      "Executing op AutoShardDataset in device /job:localhost/replica:0/task:0/device:CPU:0\n"
     ]
    }
   ],
   "source": [
    "def make_dataset(images, labels, epochs, batch_size, shuffle=True):\n",
    "    dataset = tf.data.Dataset.from_tensor_slices((images, labels))\n",
    "    if shuffle:\n",
    "        dataset = dataset.shuffle(10000)\n",
    "    dataset = dataset.repeat(epochs).batch(batch_size).prefetch(50)\n",
    "    return dataset\n",
    "\n",
    "strategy = tf.distribute.MirroredStrategy()\n",
    "\n",
    "# 数据分布式，每个dataset对应每个GPU，加快数据传输效率\n",
    "with strategy.scope():\n",
    "    batch_size_per_replica = 128\n",
    "    batch_size = batch_size_per_replica * len(logical_gpus)\n",
    "    train_dataset = make_dataset(x_train_scaled, y_train, 1, batch_size)\n",
    "    valid_dataset = make_dataset(x_valid_scaled, y_valid, 1, batch_size)\n",
    "    train_dataset_distribute = strategy.experimental_distribute_dataset(train_dataset)\n",
    "    valid_dataset_distribute = strategy.experimental_distribute_dataset(valid_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Executing op RandomUniform in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op Sub in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op Mul in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op Add in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarIsInitializedOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op LogicalNot in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op Assert in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op AssignVariableOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op ReadVariableOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op VarIsInitializedOp in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op LogicalNot in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op Assert in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op AssignVariableOp in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op Fill in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:1\n"
     ]
    }
   ],
   "source": [
    "with strategy.scope():\n",
    "    model = keras.models.Sequential()\n",
    "    model.add(keras.layers.Conv2D(filters=32, kernel_size=3, padding='same',\n",
    "                                  activation='relu', input_shape=(28,28,1)))\n",
    "    model.add(keras.layers.Conv2D(filters=32, kernel_size=3, padding='same', activation='relu'))\n",
    "    model.add(keras.layers.MaxPool2D(pool_size=2))\n",
    "\n",
    "    model.add(keras.layers.Conv2D(filters=64, kernel_size=3, padding='same', activation='relu'))\n",
    "    model.add(keras.layers.Conv2D(filters=64, kernel_size=3, padding='same', activation='relu'))\n",
    "    model.add(keras.layers.MaxPool2D(pool_size=2))\n",
    "\n",
    "    model.add(keras.layers.Conv2D(filters=128, kernel_size=3, padding='same', activation='relu'))\n",
    "    model.add(keras.layers.Conv2D(filters=128, kernel_size=3, padding='same', activation='relu'))\n",
    "    model.add(keras.layers.MaxPool2D(pool_size=2))\n",
    "\n",
    "    model.add(keras.layers.Flatten())\n",
    "    model.add(keras.layers.Dense(128, activation='relu'))\n",
    "    model.add(keras.layers.Dense(10, activation='softmax'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "conv2d (Conv2D)              (None, 28, 28, 32)        320       \n",
      "_________________________________________________________________\n",
      "conv2d_1 (Conv2D)            (None, 28, 28, 32)        9248      \n",
      "_________________________________________________________________\n",
      "max_pooling2d (MaxPooling2D) (None, 14, 14, 32)        0         \n",
      "_________________________________________________________________\n",
      "conv2d_2 (Conv2D)            (None, 14, 14, 64)        18496     \n",
      "_________________________________________________________________\n",
      "conv2d_3 (Conv2D)            (None, 14, 14, 64)        36928     \n",
      "_________________________________________________________________\n",
      "max_pooling2d_1 (MaxPooling2 (None, 7, 7, 64)          0         \n",
      "_________________________________________________________________\n",
      "conv2d_4 (Conv2D)            (None, 7, 7, 128)         73856     \n",
      "_________________________________________________________________\n",
      "conv2d_5 (Conv2D)            (None, 7, 7, 128)         147584    \n",
      "_________________________________________________________________\n",
      "max_pooling2d_2 (MaxPooling2 (None, 3, 3, 128)         0         \n",
      "_________________________________________________________________\n",
      "flatten (Flatten)            (None, 1152)              0         \n",
      "_________________________________________________________________\n",
      "dense (Dense)                (None, 128)               147584    \n",
      "_________________________________________________________________\n",
      "dense_1 (Dense)              (None, 10)                1290      \n",
      "=================================================================\n",
      "Total params: 435,306\n",
      "Trainable params: 435,306\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op OptimizeDataset in device /job:localhost/replica:0/task:0/device:CPU:0\n",
      "Executing op ModelDataset in device /job:localhost/replica:0/task:0/device:CPU:0\n",
      "Executing op AnonymousIteratorV2 in device /job:localhost/replica:0/task:0/device:CPU:0\n",
      "Executing op MakeIterator in device /job:localhost/replica:0/task:0/device:CPU:0\n",
      "Executing op IteratorGetNextSync in device /job:localhost/replica:0/task:0/device:CPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op AssignVariableOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op ReadVariableOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op AssignVariableOp in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1').\n",
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1').\n",
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1').\n",
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1').\n",
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1').\n",
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1').\n",
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1').\n",
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1').\n",
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1').\n",
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1').\n",
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n",
      "Executing op __inference_initialize_variables_1046 in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n",
      "Executing op __inference_distributed_train_step_1461 in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op DeleteIterator in device /job:localhost/replica:0/task:0/device:CPU:0\n"
     ]
    },
    {
     "ename": "UnknownError",
     "evalue": "2 root error(s) found.\n  (0) Unknown:  Failed to get convolution algorithm. This is probably because cuDNN failed to initialize, so try looking to see if a warning log message was printed above.\n\t [[node sequential/conv2d/Conv2D (defined at f:\\condaenv\\tf2_py36\\lib\\site-packages\\tensorflow_core\\python\\framework\\ops.py:1751) ]]\n\t [[truediv/_22]]\n  (1) Unknown:  Failed to get convolution algorithm. This is probably because cuDNN failed to initialize, so try looking to see if a warning log message was printed above.\n\t [[node sequential/conv2d/Conv2D (defined at f:\\condaenv\\tf2_py36\\lib\\site-packages\\tensorflow_core\\python\\framework\\ops.py:1751) ]]\n0 successful operations.\n1 derived errors ignored. [Op:__inference_distributed_train_step_1461]\n\nFunction call stack:\ndistributed_train_step -> distributed_train_step\n",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mUnknownError\u001b[0m                              Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-8-98f3e9342012>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m     52\u001b[0m         \u001b[1;32mfor\u001b[0m \u001b[0mx\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mtrain_dataset\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     53\u001b[0m             \u001b[0mstart_time\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 54\u001b[1;33m             \u001b[0mtotal_loss\u001b[0m \u001b[1;33m+=\u001b[0m \u001b[0mdistributed_train_step\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     55\u001b[0m             \u001b[0mrun_time\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m-\u001b[0m \u001b[0mstart_time\u001b[0m \u001b[1;31m# 计算时间\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     56\u001b[0m             \u001b[0mnum_batches\u001b[0m \u001b[1;33m+=\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mf:\\condaenv\\tf2_py36\\lib\\site-packages\\tensorflow_core\\python\\eager\\def_function.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, *args, **kwds)\u001b[0m\n\u001b[0;32m    455\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    456\u001b[0m     \u001b[0mtracing_count\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_get_tracing_count\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 457\u001b[1;33m     \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwds\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    458\u001b[0m     \u001b[1;32mif\u001b[0m \u001b[0mtracing_count\u001b[0m \u001b[1;33m==\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_get_tracing_count\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    459\u001b[0m       \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_call_counter\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcalled_without_tracing\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mf:\\condaenv\\tf2_py36\\lib\\site-packages\\tensorflow_core\\python\\eager\\def_function.py\u001b[0m in \u001b[0;36m_call\u001b[1;34m(self, *args, **kwds)\u001b[0m\n\u001b[0;32m    518\u001b[0m         \u001b[1;31m# Lifting succeeded, so variables are initialized and we can run the\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    519\u001b[0m         \u001b[1;31m# stateless function.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 520\u001b[1;33m         \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_stateless_fn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwds\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    521\u001b[0m     \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    522\u001b[0m       \u001b[0mcanon_args\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcanon_kwds\u001b[0m \u001b[1;33m=\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mf:\\condaenv\\tf2_py36\\lib\\site-packages\\tensorflow_core\\python\\eager\\function.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m   1821\u001b[0m     \u001b[1;34m\"\"\"Calls a graph function specialized to the inputs.\"\"\"\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1822\u001b[0m     \u001b[0mgraph_function\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mkwargs\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_maybe_define_function\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1823\u001b[1;33m     \u001b[1;32mreturn\u001b[0m \u001b[0mgraph_function\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_filtered_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m  \u001b[1;31m# pylint: disable=protected-access\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   1824\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1825\u001b[0m   \u001b[1;33m@\u001b[0m\u001b[0mproperty\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mf:\\condaenv\\tf2_py36\\lib\\site-packages\\tensorflow_core\\python\\eager\\function.py\u001b[0m in \u001b[0;36m_filtered_call\u001b[1;34m(self, args, kwargs)\u001b[0m\n\u001b[0;32m   1139\u001b[0m          if isinstance(t, (ops.Tensor,\n\u001b[0;32m   1140\u001b[0m                            resource_variable_ops.BaseResourceVariable))),\n\u001b[1;32m-> 1141\u001b[1;33m         self.captured_inputs)\n\u001b[0m\u001b[0;32m   1142\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1143\u001b[0m   \u001b[1;32mdef\u001b[0m \u001b[0m_call_flat\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcaptured_inputs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcancellation_manager\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mNone\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mf:\\condaenv\\tf2_py36\\lib\\site-packages\\tensorflow_core\\python\\eager\\function.py\u001b[0m in \u001b[0;36m_call_flat\u001b[1;34m(self, args, captured_inputs, cancellation_manager)\u001b[0m\n\u001b[0;32m   1222\u001b[0m     \u001b[1;32mif\u001b[0m \u001b[0mexecuting_eagerly\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1223\u001b[0m       flat_outputs = forward_function.call(\n\u001b[1;32m-> 1224\u001b[1;33m           ctx, args, cancellation_manager=cancellation_manager)\n\u001b[0m\u001b[0;32m   1225\u001b[0m     \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1226\u001b[0m       \u001b[0mgradient_name\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_delayed_rewrite_functions\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mregister\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mf:\\condaenv\\tf2_py36\\lib\\site-packages\\tensorflow_core\\python\\eager\\function.py\u001b[0m in \u001b[0;36mcall\u001b[1;34m(self, ctx, args, cancellation_manager)\u001b[0m\n\u001b[0;32m    509\u001b[0m               \u001b[0minputs\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    510\u001b[0m               \u001b[0mattrs\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"executor_type\"\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mexecutor_type\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;34m\"config_proto\"\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mconfig\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 511\u001b[1;33m               ctx=ctx)\n\u001b[0m\u001b[0;32m    512\u001b[0m         \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    513\u001b[0m           outputs = execute.execute_with_cancellation(\n",
      "\u001b[1;32mf:\\condaenv\\tf2_py36\\lib\\site-packages\\tensorflow_core\\python\\eager\\execute.py\u001b[0m in \u001b[0;36mquick_execute\u001b[1;34m(op_name, num_outputs, inputs, attrs, ctx, name)\u001b[0m\n\u001b[0;32m     65\u001b[0m     \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     66\u001b[0m       \u001b[0mmessage\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0me\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmessage\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 67\u001b[1;33m     \u001b[0msix\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mraise_from\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcore\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_status_to_exception\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0me\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcode\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmessage\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     68\u001b[0m   \u001b[1;32mexcept\u001b[0m \u001b[0mTypeError\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     69\u001b[0m     keras_symbolic_tensors = [\n",
      "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python36\\site-packages\\six.py\u001b[0m in \u001b[0;36mraise_from\u001b[1;34m(value, from_value)\u001b[0m\n",
      "\u001b[1;31mUnknownError\u001b[0m: 2 root error(s) found.\n  (0) Unknown:  Failed to get convolution algorithm. This is probably because cuDNN failed to initialize, so try looking to see if a warning log message was printed above.\n\t [[node sequential/conv2d/Conv2D (defined at f:\\condaenv\\tf2_py36\\lib\\site-packages\\tensorflow_core\\python\\framework\\ops.py:1751) ]]\n\t [[truediv/_22]]\n  (1) Unknown:  Failed to get convolution algorithm. This is probably because cuDNN failed to initialize, so try looking to see if a warning log message was printed above.\n\t [[node sequential/conv2d/Conv2D (defined at f:\\condaenv\\tf2_py36\\lib\\site-packages\\tensorflow_core\\python\\framework\\ops.py:1751) ]]\n0 successful operations.\n1 derived errors ignored. [Op:__inference_distributed_train_step_1461]\n\nFunction call stack:\ndistributed_train_step -> distributed_train_step\n"
     ]
    }
   ],
   "source": [
    "with strategy.scope():\n",
    "    # 1.define losses functions\n",
    "    loss_func = keras.losses.SparseCategoricalCrossentropy(\n",
    "        reduction = keras.losses.Reduction.NONE)\n",
    "    def compute_loss(labels, predictions):\n",
    "        per_replica_loss = loss_func(labels, predictions)\n",
    "        return tf.nn.compute_average_loss(per_replica_loss, global_batch_size=batch_size)\n",
    "    \n",
    "    test_loss = keras.metrics.Mean(name = 'test_loss')\n",
    "    train_accuracy = keras.metrics.SparseCategoricalAccuracy(name = 'train_accuracy')\n",
    "    test_accuracy = keras.metrics.SparseCategoricalAccuracy(name = 'test_accuracy')\n",
    "\n",
    "    optimizer = keras.optimizers.SGD(lr = 0.01)\n",
    "\n",
    "    # 有图计算：正向计算和反向传播，所以添加tf.function加速\n",
    "    # 2.define function train_step\n",
    "    def train_step(inputs):\n",
    "        images, labels = inputs\n",
    "        with tf.GradientTape() as tape:\n",
    "            predictions = model(images, training = True)\n",
    "            loss = compute_loss(labels, predictions)\n",
    "        gradients = tape.gradient(loss, model.trainable_variables)\n",
    "        optimizer.apply_gradients(zip(gradients, model.trainable_variables))\n",
    "        train_accuracy.update_state(labels, predictions)\n",
    "        return loss\n",
    "    \n",
    "    # 分布式环境下，聚合loss\n",
    "    @tf.function\n",
    "    def distributed_train_step(inputs):\n",
    "        per_replica_average_loss = strategy.experimental_run_v2(train_step, args = (inputs,))\n",
    "        return strategy.reduce(\n",
    "            tf.distribute.ReduceOp.SUM, per_replica_average_loss, axis = None)\n",
    "\n",
    "    # 3.define function test_step\n",
    "    def test_step(inputs):\n",
    "        images, labels = inputs\n",
    "        predictions = model(images)\n",
    "        t_loss = loss_func(labels, predictions)\n",
    "        test_loss.update_state(t_loss)\n",
    "        test_accuracy.update_state(labels, predictions)\n",
    "\n",
    "    @tf.function\n",
    "    def distributed_test_step(inputs):\n",
    "        strategy.experimental_run_v2(test_step, args = (inputs,))\n",
    "\n",
    "    # 4.for-loop training loop\n",
    "    epochs = 10\n",
    "    for epoch in range(epochs):\n",
    "        total_loss = 0.0\n",
    "        num_batches = 0\n",
    "        # 在训练集训练\n",
    "        for x in train_dataset:\n",
    "            start_time = time.time()\n",
    "            total_loss += distributed_train_step(x)\n",
    "            run_time = time.time() - start_time # 计算时间\n",
    "            num_batches += 1\n",
    "            print('\\rtotal_loss:%3.3f, num_batches:%3.3f, average_loss:%3.3f, time:%3.3f' \n",
    "                  % (total_loss, num_batches, total_loss / num_batches, run_time), end='')\n",
    "        train_loss = total_loss / num_batches\n",
    "\n",
    "        # 验证集验证\n",
    "        for x in valid_dataset:\n",
    "            distributed_test_step(x)\n",
    "\n",
    "        # 打印日志\n",
    "        print('\\rEpoch:%d, Loss:%3.3f, Acc:%3.3f, Val_Loss:%3.3f, Val_Acc:%3.3f' \n",
    "              % (epoch + 1, train_loss, train_accuracy.result(), \n",
    "                 test_loss.result(), test_accuracy.result()))\n",
    "        # 清空累计值\n",
    "        test_loss.reset_states()\n",
    "        train_accuracy.reset_states()\n",
    "        test_accuracy.reset_states()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
